Skip to content

Conversation

@Artem-B
Copy link
Member

@Artem-B Artem-B commented Sep 29, 2023

Make it a legal type and plumb through lowering of relevant instructions.

Make it a legal type and plumb through lowering of relevant instructions.
Copy link
Contributor

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually trying out the patch in triton it causes some invalid ptx to get emitted (see comment)

Verified that NVPTX tests pass with ptxas being able to compiler PTX produced by
llc tests.
To make things work consisstently for v4i8, we need to implement other vector
ops.
@github-actions
Copy link

github-actions bot commented Oct 5, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@Artem-B Artem-B requested a review from ThomasRaoux October 5, 2023 22:44
@Artem-B
Copy link
Member Author

Artem-B commented Oct 5, 2023

I still need to test it on live code, though we do not have much code that would end up using v4i8.

The generated PTX checked in llvm/test/CodeGen/NVPTX/i8x4-instructions.ll could use extra scrutiny.

@ThomasRaoux
Copy link
Contributor

I ran the patch on our triton kernels and I don't see any functional problems left.

Copy link
Contributor

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it required quite a lot of cases to be handled :(
Thanks for doing this, it solves some of the problems triton had with latest LLVM. Changes look good to me.

@Artem-B
Copy link
Member Author

Artem-B commented Oct 6, 2023

I see one suspicious failure in tensorflow tests. I suspect I've messed something up in v4i8 comparison.

@Artem-B Artem-B requested a review from d0k October 6, 2023 17:43
@Artem-B
Copy link
Member Author

Artem-B commented Oct 6, 2023

I see one suspicious failure in tensorflow tests. I suspect I've messed something up in v4i8 comparison.

Yup, there is a problem:

Successfully custom legalized node
 ... replacing: t10: v4i8 = BUILD_VECTOR Constant:i16<-128>, Constant:i16<-128>, Constant:i16<-128>, Constant:i16<-128>
     with:      t17: v4i8 = bitcast Constant:i32<-128>

@Artem-B
Copy link
Member Author

Artem-B commented Oct 6, 2023

I see one suspicious failure in tensorflow tests. I suspect I've messed something up in v4i8 comparison.

Yup, there is a problem:

Successfully custom legalized node
 ... replacing: t10: v4i8 = BUILD_VECTOR Constant:i16<-128>, Constant:i16<-128>, Constant:i16<-128>, Constant:i16<-128>
     with:      t17: v4i8 = bitcast Constant:i32<-128>

Resolved by 9821e90

@Artem-B
Copy link
Member Author

Artem-B commented Oct 6, 2023

Found another issue. We merge four independent byte loads with align 1 into a 32-bit load, which fails at runtime on misaligned pointers.

%t0 = type { [17 x i8] }

@shared_storage = linkonce_odr local_unnamed_addr addrspace(3) global %t0 undef, align 1

define <4 x i8> @in_v4i8(<4 x i8> %x, <4 x i8> %y) nounwind {
  %v = load <4 x i8>, ptr getelementptr inbounds (i8, ptr addrspacecast (ptr addrspace(3) @shared_storage to ptr), i64 9), align 1
  ret <4 x i8> %v
}
        mov.u64         %rd1, shared_storage;
        cvta.shared.u64         %rd2, %rd1;
        ld.u32  %r1, [%rd2+9];
        st.param.b32    [func_retval0+0], %r1;
        ret;

@Artem-B
Copy link
Member Author

Artem-B commented Oct 9, 2023

clang-format failure on GitHub is weird -- it just silently exits with an error.
I ran the same command locally and fixed one place it was not happy about.

The buildkite failure somewhere in RISC-V appears to be unrelated.

@abadams
Copy link

abadams commented Oct 15, 2023

I believe this may be causing failures for u/srem. See #69124

Edit: Also causing failures when you sign-extend the result of a <4 x i8> comparison.

justinfargnoli added a commit that referenced this pull request Oct 10, 2024
In [[NVPTX] Improve lowering of
v4i8](cbafb6f)
@Artem-B add the ability to lower ISD::BUILD_VECTOR with bfi PTX
instructions. @Artem-B did this because:
([source](#67866 (comment)))

> Under the hood byte extraction/insertion ends up as BFI/BFE
instructions, so we may as well do that in PTX, too.
https://godbolt.org/z/Tb3zWbj9b

However, the example that @Artem-B linked was targeting sm_52. On modern
architectures, ptxas uses prmt.b32.
[Example](https://godbolt.org/z/Ye4W1n84o).

Thus, remove uses of NVPTXISD::BFI in favor of NVPTXISD::PRMT.
ericastor pushed a commit to ericastor/llvm-project that referenced this pull request Oct 10, 2024
In [[NVPTX] Improve lowering of
v4i8](llvm@cbafb6f)
@Artem-B add the ability to lower ISD::BUILD_VECTOR with bfi PTX
instructions. @Artem-B did this because:
([source](llvm#67866 (comment)))

> Under the hood byte extraction/insertion ends up as BFI/BFE
instructions, so we may as well do that in PTX, too.
https://godbolt.org/z/Tb3zWbj9b

However, the example that @Artem-B linked was targeting sm_52. On modern
architectures, ptxas uses prmt.b32.
[Example](https://godbolt.org/z/Ye4W1n84o).

Thus, remove uses of NVPTXISD::BFI in favor of NVPTXISD::PRMT.
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
In [[NVPTX] Improve lowering of
v4i8](llvm@cbafb6f)
@Artem-B add the ability to lower ISD::BUILD_VECTOR with bfi PTX
instructions. @Artem-B did this because:
([source](llvm#67866 (comment)))

> Under the hood byte extraction/insertion ends up as BFI/BFE
instructions, so we may as well do that in PTX, too.
https://godbolt.org/z/Tb3zWbj9b

However, the example that @Artem-B linked was targeting sm_52. On modern
architectures, ptxas uses prmt.b32.
[Example](https://godbolt.org/z/Ye4W1n84o).

Thus, remove uses of NVPTXISD::BFI in favor of NVPTXISD::PRMT.
justinfargnoli added a commit that referenced this pull request Oct 31, 2024
Fix
[failure](#110766 (comment))
identified by @akuegel.

---

In [[NVPTX] Improve lowering of
v4i8](cbafb6f)
@Artem-B add the ability to lower ISD::BUILD_VECTOR with bfi PTX
instructions. @Artem-B did this because:
(#67866 (comment))

Under the hood byte extraction/insertion ends up as BFI/BFE
instructions, so we may as well do that in PTX, too.
https://godbolt.org/z/Tb3zWbj9b

However, the example that @Artem-B linked was targeting sm_52. On modern
architectures, ptxas uses prmt.b32.
[Example](https://godbolt.org/z/Ye4W1n84o).

Thus, remove uses of NVPTXISD::BFI in favor of NVPTXISD::PRMT.
smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
Fix
[failure](llvm#110766 (comment))
identified by @akuegel.

---

In [[NVPTX] Improve lowering of
v4i8](llvm@cbafb6f)
@Artem-B add the ability to lower ISD::BUILD_VECTOR with bfi PTX
instructions. @Artem-B did this because:
(llvm#67866 (comment))

Under the hood byte extraction/insertion ends up as BFI/BFE
instructions, so we may as well do that in PTX, too.
https://godbolt.org/z/Tb3zWbj9b

However, the example that @Artem-B linked was targeting sm_52. On modern
architectures, ptxas uses prmt.b32.
[Example](https://godbolt.org/z/Ye4W1n84o).

Thus, remove uses of NVPTXISD::BFI in favor of NVPTXISD::PRMT.
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
Fix
[failure](llvm#110766 (comment))
identified by @akuegel.

---

In [[NVPTX] Improve lowering of
v4i8](llvm@cbafb6f)
@Artem-B add the ability to lower ISD::BUILD_VECTOR with bfi PTX
instructions. @Artem-B did this because:
(llvm#67866 (comment))

Under the hood byte extraction/insertion ends up as BFI/BFE
instructions, so we may as well do that in PTX, too.
https://godbolt.org/z/Tb3zWbj9b

However, the example that @Artem-B linked was targeting sm_52. On modern
architectures, ptxas uses prmt.b32.
[Example](https://godbolt.org/z/Ye4W1n84o).

Thus, remove uses of NVPTXISD::BFI in favor of NVPTXISD::PRMT.
@llvmbot
Copy link
Member

llvmbot commented Jul 14, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Artem Belevich (Artem-B)

Changes

Make it a legal type and plumb through lowering of relevant instructions.


Patch is 122.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67866.diff

15 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+31)
  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h (+2)
  • (modified) llvm/lib/Target/NVPTX/NVPTX.h (+12)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+14-7)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+228-47)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+5)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+165-35)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/extractelement.ll (+54-1)
  • (modified) llvm/test/CodeGen/NVPTX/i16x2-instructions.ll (+1-1)
  • (added) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+1272)
  • (modified) llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/param-load-store.ll (+12-14)
  • (modified) llvm/test/CodeGen/NVPTX/unfold-masked-merge-vector-variablemask.ll (+96-429)
  • (modified) llvm/test/CodeGen/NVPTX/vec8.ll (+2-3)
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 5d27accdc198c..b7a20c351f5ff 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -309,3 +309,34 @@ void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
   const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol();
   O << Sym.getName();
 }
+
+void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
+                                     raw_ostream &O, const char *Modifier) {
+  const MCOperand &MO = MI->getOperand(OpNum);
+  int64_t Imm = MO.getImm();
+
+  switch (Imm) {
+  default:
+    return;
+  case NVPTX::PTXPrmtMode::NONE:
+    break;
+  case NVPTX::PTXPrmtMode::F4E:
+    O << ".f4e";
+    break;
+  case NVPTX::PTXPrmtMode::B4E:
+    O << ".b4e";
+    break;
+  case NVPTX::PTXPrmtMode::RC8:
+    O << ".rc8";
+    break;
+  case NVPTX::PTXPrmtMode::ECL:
+    O << ".ecl";
+    break;
+  case NVPTX::PTXPrmtMode::ECR:
+    O << ".ecr";
+    break;
+  case NVPTX::PTXPrmtMode::RC16:
+    O << ".rc16";
+    break;
+  }
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 49ad3f269229d..e6954f861cd10 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -47,6 +47,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
                        raw_ostream &O, const char *Modifier = nullptr);
   void printProtoIdent(const MCInst *MI, int OpNum,
                        raw_ostream &O, const char *Modifier = nullptr);
+  void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
+                     const char *Modifier = nullptr);
 };
 
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 8dc68911fff0c..07ee34968b023 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -181,6 +181,18 @@ enum CmpMode {
   FTZ_FLAG = 0x100
 };
 }
+
+namespace PTXPrmtMode {
+enum PrmtMode {
+  NONE,
+  F4E,
+  B4E,
+  RC8,
+  ECL,
+  ECR,
+  RC16,
+};
+}
 }
 void initializeNVPTXDAGToDAGISelPass(PassRegistry &);
 } // namespace llvm
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 0aef2591c6e23..68391cdb6ff17 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -14,6 +14,7 @@
 #include "MCTargetDesc/NVPTXBaseInfo.h"
 #include "NVPTXUtilities.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
@@ -829,6 +830,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
   case MVT::v2f16:
   case MVT::v2bf16:
   case MVT::v2i16:
+  case MVT::v4i8:
     return Opcode_i32;
   case MVT::f32:
     return Opcode_f32;
@@ -910,7 +912,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   // Vector Setting
   unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
   if (SimpleVT.isVector()) {
-    assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
+    assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
+           "Unexpected vector type");
     // v2f16/v2bf16/v2i16 is loaded using ld.b32
     fromTypeWidth = 32;
   }
@@ -1254,6 +1257,7 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
   SDLoc DL(N);
   SDNode *LD;
   SDValue Base, Offset, Addr;
+  EVT OrigType = N->getValueType(0);
 
   EVT EltVT = Mem->getMemoryVT();
   unsigned NumElts = 1;
@@ -1261,12 +1265,15 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
     NumElts = EltVT.getVectorNumElements();
     EltVT = EltVT.getVectorElementType();
     // vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
-    if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
-        (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
-        (EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
+    if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
+        (EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
+        (EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
       assert(NumElts % 2 == 0 && "Vector must have even number of elements");
-      EltVT = N->getValueType(0);
+      EltVT = OrigType;
       NumElts /= 2;
+    } else if (OrigType == MVT::v4i8) {
+      EltVT = OrigType;
+      NumElts = 1;
     }
   }
 
@@ -1601,7 +1608,6 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
   // concept of sign-/zero-extension, so emulate it here by adding an explicit
   // CVT instruction. Ptxas should clean up any redundancies here.
 
-  EVT OrigType = N->getValueType(0);
   LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);
 
   if (OrigType != EltVT &&
@@ -1679,7 +1685,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
   MVT ScalarVT = SimpleVT.getScalarType();
   unsigned toTypeWidth = ScalarVT.getSizeInBits();
   if (SimpleVT.isVector()) {
-    assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
+    assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
+           "Unexpected vector type");
     // v2x16 is stored using st.b32
     toTypeWidth = 32;
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b24aae4792ce6..36da2e7b40efa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -221,6 +221,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
           llvm_unreachable("Unexpected type");
         }
         NumElts /= 2;
+      } else if (EltVT.getSimpleVT() == MVT::i8 &&
+                 (NumElts % 4 == 0 || NumElts == 3)) {
+        // v*i8 are formally lowered as v4i8
+        EltVT = MVT::v4i8;
+        NumElts = (NumElts + 3) / 4;
       }
       for (unsigned j = 0; j != NumElts; ++j) {
         ValueVTs.push_back(EltVT);
@@ -458,6 +463,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
   addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
+  addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
@@ -491,10 +497,26 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);
 
+  setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);
+  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
+  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
+  setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
+  // Only logical ops can be done on v4i8 directly, others must be done
+  // elementwise.
+  setOperationAction(
+      {ISD::ADD,       ISD::MUL,        ISD::ABS,        ISD::SMIN,
+       ISD::SMAX,      ISD::UMIN,       ISD::UMAX,       ISD::CTPOP,
+       ISD::CTLZ,      ISD::ADD,        ISD::SUB,        ISD::MUL,
+       ISD::SHL,       ISD::SREM,       ISD::UREM,       ISD::SDIV,
+       ISD::UDIV,      ISD::SRA,        ISD::SRL,        ISD::MULHS,
+       ISD::MULHU,     ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::SINT_TO_FP,
+       ISD::UINT_TO_FP},
+      MVT::v4i8, Expand);
+
   // Operations not directly supported by NVPTX.
-  for (MVT VT :
-       {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32, MVT::f64,
-        MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}) {
+  for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
+                 MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::v4i8,
+                 MVT::i32, MVT::i64}) {
     setOperationAction(ISD::SELECT_CC, VT, Expand);
     setOperationAction(ISD::BR_CC, VT, Expand);
   }
@@ -672,7 +694,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
 
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
-                       ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT});
+                       ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT,
+                       ISD::VSELECT});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -881,6 +904,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     return "NVPTXISD::FUN_SHFR_CLAMP";
   case NVPTXISD::IMAD:
     return "NVPTXISD::IMAD";
+  case NVPTXISD::BFE:
+    return "NVPTXISD::BFE";
+  case NVPTXISD::BFI:
+    return "NVPTXISD::BFI";
+  case NVPTXISD::PRMT:
+    return "NVPTXISD::PRMT";
   case NVPTXISD::SETP_F16X2:
     return "NVPTXISD::SETP_F16X2";
   case NVPTXISD::Dummy:
@@ -2150,58 +2179,98 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
-// We can init constant f16x2 with a single .b32 move.  Normally it
+// We can init constant f16x2/v2i16/v4i8 with a single .b32 move.  Normally it
 // would get lowered as two constant loads and vector-packing move.
-//        mov.b16         %h1, 0x4000;
-//        mov.b16         %h2, 0x3C00;
-//        mov.b32         %hh2, {%h2, %h1};
 // Instead we want just a constant move:
-//        mov.b32         %hh2, 0x40003C00
-//
-// This results in better SASS code with CUDA 7.x. Ptxas in CUDA 8.0
-// generates good SASS in both cases.
+//        mov.b32         %r2, 0x40003C00
 SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
                                                SelectionDAG &DAG) const {
   EVT VT = Op->getValueType(0);
-  if (!(Isv2x16VT(VT)))
+  if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
     return Op;
-  APInt E0;
-  APInt E1;
-  if (VT == MVT::v2f16 || VT == MVT::v2bf16) {
-    if (!(isa<ConstantFPSDNode>(Op->getOperand(0)) &&
-          isa<ConstantFPSDNode>(Op->getOperand(1))))
-      return Op;
-
-    E0 = cast<ConstantFPSDNode>(Op->getOperand(0))
-             ->getValueAPF()
-             .bitcastToAPInt();
-    E1 = cast<ConstantFPSDNode>(Op->getOperand(1))
-             ->getValueAPF()
-             .bitcastToAPInt();
-  } else {
-    assert(VT == MVT::v2i16);
-    if (!(isa<ConstantSDNode>(Op->getOperand(0)) &&
-          isa<ConstantSDNode>(Op->getOperand(1))))
-      return Op;
 
-    E0 = cast<ConstantSDNode>(Op->getOperand(0))->getAPIntValue();
-    E1 = cast<ConstantSDNode>(Op->getOperand(1))->getAPIntValue();
+  SDLoc DL(Op);
+
+  if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
+        return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
+               isa<ConstantFPSDNode>(Operand);
+      })) {
+    // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
+    // to optimize calculation of constant parts.
+    if (VT == MVT::v4i8) {
+      SDValue C8 = DAG.getConstant(8, DL, MVT::i32);
+      SDValue E01 = DAG.getNode(
+          NVPTXISD::BFI, DL, MVT::i32,
+          DAG.getAnyExtOrTrunc(Op->getOperand(1), DL, MVT::i32),
+          DAG.getAnyExtOrTrunc(Op->getOperand(0), DL, MVT::i32), C8, C8);
+      SDValue E012 =
+          DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
+                      DAG.getAnyExtOrTrunc(Op->getOperand(2), DL, MVT::i32),
+                      E01, DAG.getConstant(16, DL, MVT::i32), C8);
+      SDValue E0123 =
+          DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
+                      DAG.getAnyExtOrTrunc(Op->getOperand(3), DL, MVT::i32),
+                      E012, DAG.getConstant(24, DL, MVT::i32), C8);
+      return DAG.getNode(ISD::BITCAST, DL, VT, E0123);
+    }
+    return Op;
   }
-  SDValue Const =
-      DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
+
+  // Get value or the Nth operand as an APInt(32). Undef values treated as 0.
+  auto GetOperand = [](SDValue Op, int N) -> APInt {
+    const SDValue &Operand = Op->getOperand(N);
+    EVT VT = Op->getValueType(0);
+    if (Operand->isUndef())
+      return APInt(32, 0);
+    APInt Value;
+    if (VT == MVT::v2f16 || VT == MVT::v2bf16)
+      Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
+    else if (VT == MVT::v2i16 || VT == MVT::v4i8)
+      Value = cast<ConstantSDNode>(Operand)->getAPIntValue();
+    else
+      llvm_unreachable("Unsupported type");
+    // i8 values are carried around as i16, so we need to zero out upper bits,
+    // so they do not get in the way of combining individual byte values
+    if (VT == MVT::v4i8)
+      Value = Value.trunc(8);
+    return Value.zext(32);
+  };
+  APInt Value;
+  if (Isv2x16VT(VT)) {
+    Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16);
+  } else if (VT == MVT::v4i8) {
+    Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) |
+            GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24);
+  } else {
+    llvm_unreachable("Unsupported type");
+  }
+  SDValue Const = DAG.getConstant(Value, SDLoc(Op), MVT::i32);
   return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
 }
 
 SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
                                                      SelectionDAG &DAG) const {
   SDValue Index = Op->getOperand(1);
+  SDValue Vector = Op->getOperand(0);
+  SDLoc DL(Op);
+  EVT VectorVT = Vector.getValueType();
+
+  if (VectorVT == MVT::v4i8) {
+    SDValue BFE =
+        DAG.getNode(NVPTXISD::BFE, DL, MVT::i32,
+                    {Vector,
+                     DAG.getNode(ISD::MUL, DL, MVT::i32,
+                                 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
+                                 DAG.getConstant(8, DL, MVT::i32)),
+                     DAG.getConstant(8, DL, MVT::i32)});
+    return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
+  }
+
   // Constant index will be matched by tablegen.
   if (isa<ConstantSDNode>(Index.getNode()))
     return Op;
 
   // Extract individual elements and select one of them.
-  SDValue Vector = Op->getOperand(0);
-  EVT VectorVT = Vector.getValueType();
   assert(Isv2x16VT(VectorVT) && "Unexpected vector type.");
   EVT EltVT = VectorVT.getVectorElementType();
 
@@ -2214,6 +2283,49 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
                          ISD::CondCode::SETEQ);
 }
 
+SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  SDValue Vector = Op->getOperand(0);
+  EVT VectorVT = Vector.getValueType();
+
+  if (VectorVT != MVT::v4i8)
+    return Op;
+  SDLoc DL(Op);
+  SDValue Value = Op->getOperand(1);
+  if (Value->isUndef())
+    return Vector;
+
+  SDValue Index = Op->getOperand(2);
+
+  SDValue BFI =
+      DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
+                  {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector,
+                   DAG.getNode(ISD::MUL, DL, MVT::i32,
+                               DAG.getZExtOrTrunc(Index, DL, MVT::i32),
+                               DAG.getConstant(8, DL, MVT::i32)),
+                   DAG.getConstant(8, DL, MVT::i32)});
+  return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI);
+}
+
+SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDValue V1 = Op.getOperand(0);
+  EVT VectorVT = V1.getValueType();
+  if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
+    return Op;
+
+  // Lower shuffle to PRMT instruction.
+  const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
+  SDValue V2 = Op.getOperand(1);
+  uint32_t Selector = 0;
+  for (auto I : llvm::enumerate(SVN->getMask()))
+    Selector |= (I.value() << (I.index() * 4));
+
+  SDLoc DL(Op);
+  return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
+                     DAG.getConstant(Selector, DL, MVT::i32),
+                     DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
+}
 /// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
 /// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
 ///    amount, or
@@ -2464,6 +2576,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return Op;
   case ISD::EXTRACT_VECTOR_ELT:
     return LowerEXTRACT_VECTOR_ELT(Op, DAG);
+  case ISD::INSERT_VECTOR_ELT:
+    return LowerINSERT_VECTOR_ELT(Op, DAG);
+  case ISD::VECTOR_SHUFFLE:
+    return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
   case ISD::STORE:
@@ -2578,9 +2694,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   if (Op.getValueType() == MVT::i1)
     return LowerLOADi1(Op, DAG);
 
-  // v2f16/v2bf16/v2i16 are legal, so we can't rely on legalizer to handle
+  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
   // unaligned loads and have to handle it here.
-  if (Isv2x16VT(Op.getValueType())) {
+  EVT VT = Op.getValueType();
+  if (Isv2x16VT(VT) || VT == MVT::v4i8) {
     LoadSDNode *Load = cast<LoadSDNode>(Op);
     EVT MemVT = Load->getMemoryVT();
     if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2625,13 +2742,13 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
 
   // v2f16 is legal, so we can't rely on legalizer to handle unaligned
   // stores and have to handle it here.
-  if (Isv2x16VT(VT) &&
+  if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                       VT, *Store->getMemOperand()))
     return expandUnalignedStore(Store, DAG);
 
   // v2f16, v2bf16 and v2i16 don't need special handling.
-  if (Isv2x16VT(VT))
+  if (Isv2x16VT(VT) || VT == MVT::v4i8)
     return SDValue();
 
   if (VT.isVector())
@@ -2903,7 +3020,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           EVT LoadVT = EltVT;
           if (EltVT == MVT::i1)
             LoadVT = MVT::i8;
-          else if (Isv2x16VT(EltVT))
+          else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
             // getLoad needs a vector type, but it can't handle
             // vectors which contain v2f16 or v2bf16 elements. So we must load
             // using i32 here and then bitcast back.
@@ -2929,7 +3046,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
             if (EltVT == MVT::i1)
               Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
             // v2f16 was loaded as an i32. Now we must bitcast it back.
-            else if (Isv2x16VT(EltVT))
+            else if (EltVT != LoadVT)
               Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
 
             // If a promoted integer type is used, truncate down to the original
@@ -4975,6 +5092,32 @@ static SDValue PerformANDCombine(SDNode *N,
   }
 
   SDValue AExt;
+
+  // Convert BFE-> truncate i16 -> and 255
+  // To just BFE-> truncate i16, as the value already has all the bits in the
+  // right places.
+  if (Val.getOpcode() == ISD::TRUNCATE) {
+    SDValue BFE = Val.getOperand(0);
+    if (BFE.getOpcode() != NVPTXISD::BFE)
+      return SDValue();
+
+    ConstantSDNode *BFEBits = dyn_cast<ConstantSDNode>(BFE.getOperand(0));
+    if (!BFEBits)
+      return SDValue();
+    uint64_t BFEBitsVal = BFEBits->getZExtValue();
+
+    ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
+    if (!MaskCnst) {
+      // Not an AND with a constant
+      return SDValue();
+    }
+    uint64_t MaskVal = MaskCnst->getZExtValue();
+
+    if (MaskVal != (uint64_t(1) << BFEBitsVal) - 1)
+      return SDValue();
+    // If we get here, the AND is unnecessary.  Just replace it with the trunc
+    DCI.CombineTo(N, Val, false);
+  }
   // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
   if (Val.getOpcode() == ISD::ANY_EXTEND) {
     AExt = Val;
@@ -5254,13 +5397,15 @@ static SDValue PerformSETCCCombine(SDNode *N,
 static SDValue PerformEXTRACTCombine(SDNode *N,
                                      Tar...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants